{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1b56860-db41-4829-b395-176e11987cdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install daft\n",
    "%pip install Pillow torch torchvision"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c571e01d",
   "metadata": {},
   "source": [
    "```{hint}\n",
    "✨✨✨ **Run this notebook on Google Colab** ✨✨✨\n",
    "\n",
    "You can [run this notebook yourself with Google Colab](https://colab.research.google.com/github/Eventual-Inc/Daft/blob/main/tutorials/mnist.ipynb)!\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b14abf5-a183-4bfb-9b15-a9a54b744fce",
   "metadata": {},
   "source": [
    "# MNIST Daft Tutorial\n",
    "\n",
    "The MNIST dataset is a \"large database of handwritten digits that is commonly used for training various image processing systems\"."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "252b5128-99c2-49dd-b624-6e4b21275959",
   "metadata": {},
   "source": [
    "## Loading Data\n",
    "\n",
    "The MNIST test set is available as a gzipped JSON file. Let's load it into a Daft DataFrame!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fc63a3ad-0e0a-4ab3-9cc0-cbec8bdd0632",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-04-21 11:44:02.554 | INFO     | daft.context:runner:88 - Using PyRunner\n"
     ]
    }
   ],
   "source": [
    "import daft\n",
    "from daft import DataType, col, udf\n",
    "\n",
    "URL = \"https://github.com/Eventual-Inc/mnist-json/raw/master/mnist_handwritten_test.json.gz\"\n",
    "images_df = daft.read_json(URL)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d52f6032-6619-4682-8305-2ed65bdc194c",
   "metadata": {},
   "source": [
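    "To peek at the dataset, have your notebook display the `images_df` DataFrame that was just created. Because Daft evaluates lazily, this shows only the schema until the data is materialized; calling `images_df.show(10)` then displays the first 10 rows."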
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "73a71adf-3b2e-4ec5-a0d2-34ad8eec734c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "    <table class=\"dataframe\">\n",
       "<tbody>\n",
       "<tr><td>image<br>List[Int64]</td><td>label<br>Int64</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    <small>(No data to display: Dataframe not materialized)</small>\n",
       "</div>"
      ],
      "text/plain": [
       "+---------------+---------+\n",
       "| image         | label   |\n",
       "| List[Int64]   | Int64   |\n",
       "+===============+=========+\n",
       "+---------------+---------+\n",
       "(No data to display: Dataframe not materialized)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "images_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4787caab-d7d1-4fd4-9a76-ffb08a404a31",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "    <table class=\"dataframe\">\n",
       "<thead>\n",
       "<tr><th>image<br>List[Int64]                                        </th><th style=\"text-align: right;\">  label<br>Int64</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               7</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               2</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               1</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               0</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               4</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               1</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               4</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               9</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               5</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               9</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    <small>(Showing first 10 rows)</small>\n",
       "</div>"
      ],
      "text/plain": [
       "+----------------------+---------+\n",
       "| image                |   label |\n",
       "| List[Int64]          |   Int64 |\n",
       "+======================+=========+\n",
       "| [0, 0, 0, 0, 0, 0,   |       7 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         |\n",
       "| 0, 0, 0, 0, 0, 0,... |         |\n",
       "+----------------------+---------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       2 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         |\n",
       "| 0, 0, 0, 0, 0, 0,... |         |\n",
       "+----------------------+---------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       1 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         |\n",
       "| 0, 0, 0, 0, 0, 0,... |         |\n",
       "+----------------------+---------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       0 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         |\n",
       "| 0, 0, 0, 0, 0, 0,... |         |\n",
       "+----------------------+---------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       4 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         |\n",
       "| 0, 0, 0, 0, 0, 0,... |         |\n",
       "+----------------------+---------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       1 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         |\n",
       "| 0, 0, 0, 0, 0, 0,... |         |\n",
       "+----------------------+---------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       4 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         |\n",
       "| 0, 0, 0, 0, 0, 0,... |         |\n",
       "+----------------------+---------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       9 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         |\n",
       "| 0, 0, 0, 0, 0, 0,... |         |\n",
       "+----------------------+---------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       5 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         |\n",
       "| 0, 0, 0, 0, 0, 0,... |         |\n",
       "+----------------------+---------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       9 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         |\n",
       "| 0, 0, 0, 0, 0, 0,... |         |\n",
       "+----------------------+---------+\n",
       "(Showing first 10 rows)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "images_df.show(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "426f1bbb-e1c0-4fd6-b84e-cbb1ab309ff9",
   "metadata": {},
   "source": [
    "You just loaded your first Daft DataFrame! It consists of two columns:\n",
    "1. The \"image\" column has type `List[Int64]`: each row contains a list of integers representing the pixels of one image.\n",
    "2. The \"label\" column has type `Int64` and holds the digit that the image depicts."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a7872e3-9860-4867-8a8c-61a69f69e334",
   "metadata": {},
   "source": [
    "## Processing Columns with User-Defined Functions (UDF)\n",
    "\n",
    "Our JSON file stores each image as a one-dimensional list of pixels rather than as a two-dimensional image. We can transform the data in this column by instructing Daft to run a function on every row in the column, like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "af857589-b28a-4ee0-91cd-dc7a01ff4c07",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "images_df = images_df.with_column(\n",
    "    \"image_2d\",\n",
    "    col(\"image\").apply(lambda img: np.array(img).reshape(28, 28), return_dtype=DataType.python()),\n",
    ")"
   ]
  },
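  {
   "cell_type": "markdown",
   "id": "b7c1e2a4-reshape-sketch",
   "metadata": {},
   "source": [
    "To see what the reshape does to a single row, here is a minimal sketch in plain NumPy, independent of Daft. The `flat` list is a made-up stand-in for one row of the \"image\" column, not real MNIST data.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in for one row of the image column: 784 values.\n",
    "flat = list(range(28 * 28))\n",
    "\n",
    "# reshape(28, 28) keeps every value and only changes the shape.\n",
    "# NumPy fills the grid row by row, so element (r, c) of the grid\n",
    "# comes from index r * 28 + c of the flat list.\n",
    "grid = np.array(flat).reshape(28, 28)\n",
    "\n",
    "assert grid.shape == (28, 28)\n",
    "assert grid[3, 5] == flat[3 * 28 + 5]\n",
    "```\n",
    "\n",
    "The lambda passed to `.apply` above performs this same conversion once per row."
   ]
  },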
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d1212a7e-949a-4881-ba54-9d7e7eb31e6f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "    <table class=\"dataframe\">\n",
       "<thead>\n",
       "<tr><th>image<br>List[Int64]                                        </th><th style=\"text-align: right;\">  label<br>Int64</th><th>image_2d<br>Python                               </th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               7</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               2</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               1</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               0</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               4</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               1</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               4</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               9</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               5</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               9</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    <small>(Showing first 10 rows)</small>\n",
       "</div>"
      ],
      "text/plain": [
       "+----------------------+---------+----------------------+\n",
       "| image                |   label | image_2d             |\n",
       "| List[Int64]          |   Int64 | Python               |\n",
       "+======================+=========+======================+\n",
       "| [0, 0, 0, 0, 0, 0,   |       7 | [[  0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... |\n",
       "+----------------------+---------+----------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       2 | [[  0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... |\n",
       "+----------------------+---------+----------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       1 | [[  0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... |\n",
       "+----------------------+---------+----------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       0 | [[  0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... |\n",
       "+----------------------+---------+----------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       4 | [[  0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... |\n",
       "+----------------------+---------+----------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       1 | [[  0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... |\n",
       "+----------------------+---------+----------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       4 | [[  0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... |\n",
       "+----------------------+---------+----------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       9 | [[  0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... |\n",
       "+----------------------+---------+----------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       5 | [[  0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... |\n",
       "+----------------------+---------+----------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       9 | [[  0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... |\n",
       "+----------------------+---------+----------------------+\n",
       "(Showing first 10 rows)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "images_df.show(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd0d2664-12d8-4964-85cd-a67f8fee1384",
   "metadata": {},
   "source": [
    "Great, but we can do one better: let's convert these two-dimensional arrays into images. Computers speak in pixels and arrays, but humans do much better with visual patterns!\n",
    "\n",
    "To do this, we again use the `.apply` expression method, which runs an arbitrary Python function on every row of a given column. Because the function returns a PIL image rather than a native Daft type, we pass `return_dtype=DataType.python()`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e585303a-7c83-4a31-afbb-461c951481f7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "images_df = images_df.with_column(\n",
    "    \"pil_image\",\n",
    "    col(\"image_2d\").apply(lambda arr: Image.fromarray(arr.astype(np.uint8)), return_dtype=DataType.python()),\n",
    ")"
   ]
  },
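  {
   "cell_type": "markdown",
   "id": "c8d2f3b5-fromarray-sketch",
   "metadata": {},
   "source": [
    "As a rough illustration of what the lambda does for one row, the sketch below converts a single array with Pillow outside of Daft. The array `arr` is invented for the example; it is not taken from the dataset.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "# Hypothetical 28x28 array standing in for one row of the image_2d column.\n",
    "arr = np.zeros((28, 28), dtype=np.int64)\n",
    "arr[10:18, 10:18] = 255  # a white square on a black background\n",
    "\n",
    "# Casting to uint8 gives Pillow an 8-bit two-dimensional array,\n",
    "# which it reads as a grayscale (mode \"L\") image.\n",
    "img = Image.fromarray(arr.astype(np.uint8))\n",
    "\n",
    "assert img.mode == \"L\"\n",
    "assert img.size == (28, 28)\n",
    "```\n",
    "\n",
    "The cast is safe here because MNIST pixel intensities lie between 0 and 255, so they fit in 8 bits."
   ]
  },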
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "59b655ed-13aa-4764-acd4-a00beb91ec2f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "    <table class=\"dataframe\">\n",
       "<thead>\n",
       "<tr><th>image<br>List[Int64]                                        </th><th style=\"text-align: right;\">  label<br>Int64</th><th>image_2d<br>Python                               </th><th>pil_image<br>Python                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 </th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               7</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APA4Lea6nSC3hkmmc4SONSzMfQAdadc2dzZSmK7tpoJB/BKhU/kahoq1pupXuj6hDf6dcyW13CSY5YzhlyCDj8CRXXWvxe8b20SxtrH2lVOQbqCOU9OmWUn/APVXUfEfxBqCfDzSNJ16S2uNd1JxqEqpbohtIMYjQbQBlsEnv1HpXj9Fdx8OvDNlqNxe+IdeVh4e0VPPucLnznyNkQ/3j1/LjOa57xPr9z4n8R3usXQ2vcSZVB0jQcKo9gABWRRXSxeOdXt/A0nhGAW0WnSzGaZ1j/ey8g7SxOMZA6AHjrXNUV//2Q==\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FC36F50>\" />                                                    </td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               2</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APAY42lkWNBl2IVR6k17BB8L/BemalZaB4h8TXr+I7oKhtNNi3rDI3TcdpyB36HvwK8z8UaL/wAI74o1LRxOs4s52iEg/iAPGff196yaK9d+H1lbeCPCdz8SNZjWW5bdBpFtJw0jn5TIDn/eHToCe4ryvUb+51XUrnULyQyXNzK0srnuzHJqtRX0J4utvBHxCXSLez+INlo+m2dqqQ2EkQCIf7xLMoB24XB9Pc1wp8D+AdMc/wBrfEaCbazDy9Os2l3YHGHBIHPtj3riNfg0a31V00G8ubuw2qUkuYhG+ccgge9ZlFFFf//Z\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FC37040>\" />                                                </td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               1</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+lALEAAkngAd62PFGgjw1rR0trrz5o4Ynn/d7PLkZAzR9TnbnGfboKxqK634a6PHrPjrT1uAhs7Mm9ut/TyovmOfrgD8awte1M614h1LVCGX7ZcyThWOSoZiQM+2cVn0V6J4RFvo/wu8X63NKi3F6qaVarxuYsQ0nvjaR+XevO6KKKKK//9k=\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FC36920>\" />                                                                                                                                                                                </td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               0</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+u6+Hnwx1Lx9cTlZmsLGGMn7W8JdWfOAijIz3yc8Y+lcdqFjPpt9NaXMbJJE5U7lIzg4yM9qrVJbwtcXMUKqzNI4QBRkkk44r6F8c/Em3+GdnD4I8KWyi5sYVWS4lXiEsA/TGHZg24nplu5zjC11T8SPgyfFupp5WtaK7RG4Cqi3SFl4J46AjA9QcferxOtDQr9NK8Q6ZqMiF0tLuKdlHUhXDEfpX0VqHgHwJ461y88a3niUy2Ny0P7uKdIY0KoqbXZgTzt6fKetcF8VfiLpOo6Ta+E/CStBpNsx+0MkflpKRjaqgdVzkkkcnB9z5FRRRRX//2Q==\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FC36FE0>\" />                            </td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               4</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+tC60W7s9E0/Vpggtr95kgGTuPl7QxPGMZbA57Gohpl0dHOqhAbRZxbswYZDldwBHuAefY1Uq/oenDWNf07TDIIxd3McBc/w7mAz+tb3xG1NLvxZcadZxmHS9JJsLKDGAiIcM2PVm3MSeeea2PElpH4a+Enh/SWTN7rVwdXmc8FIwmyNfoQxP5151T4ZpLeeOeF2jljYOjqcFWByCK7GT4qeKJRvkk097njN02nQNMSO5YpyccV0Hx4u3uPFOipI/72PRoTNGDgRyFnJG3+E9OPTFeV0UVJPPNczvPcSvLM53PJIxZmPqSetR1//Z\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FC37010>\" />                                        </td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               1</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+nIjSOqIpZmOAB1Jrb8V+EtT8GarFpuq+R9pkgSfEMm8KGzwfQgggj8sjBrCorvvg3oI1z4kaeZULW1jm8lO3IGz7uf8AgW2sTx54iPirxvqurhswzTFYOMful+VOO3ygE+5Nc5RXqvgG4Hh34TeM/EUIH22Yx6dC4bDRhupH/fYP/Aa8qooqVLidIJIEmkWGQgvGGIViOmR0OKior//Z\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FC37070>\" />                                                                                                                                                </td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               4</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn8DJwOtbWveF9R8NxWB1PyYp7yHz1tRJmWJD08xf4SRyAefyNVtL0PUNZivpbOJWisYDcXEjuEWNB6knqTwB1JrOrc8HXGkWnjDS7jXU36XHOGuBgnjsSByQDg8eneu68W+F9I8QeIbrXm+I+iTR30jSjzN/moucBPLUEjCgAA4zjgVB8RYbPwLo1v4D0m6FxM7C71e52KGlfH7uPjkKvzHB/vCvMKK734O6JFrPxCtmnjEsOnxPfNDs3GUpjaoHruZT+Fc94ui1ZfE15PrcSw6jdSG4miDKTGXJO0gH5T7Hkd6w6Kmtru5spvOtbiWCXBXfE5VsHgjIqEkk5JyaK//9k=\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FC370A0>\" />            </td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               9</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+tODw9q9zod1rUOnztplqVE11twiksFABPU5I6ZxmsyiivqLRLvU9Lm0Tw8ui2EngQ6RHJdX8ynyjuQu7MzYXlv4SP4sn28K+JegweHPH+qafaW8kFnvEturEEbGGflx/DkkD6c81yVFasniXWpfD8egyancNpUb70tS3yA/4c5x0rsdQuB48+G8V40gbXvDSeXcA8vcWZICv6koSAfQHPevOaKKuabqt9o9xJPYXDQSSwvBJgAh43GGUg8EEf0PUVTr/2Q==\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FC37130>\" />                                                                                            </td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               5</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+tnw34W1nxbqJsNFszczqu9xvVQi5xkliBjmvRV/Z78QxRxPe6zo9qHzkNKxKnB/2QDzjoe9eWarp76TrF7p0kscr2lxJA0kRyrlGKkqfQ44rp/hh4Qi8aeMY7C53m0gha6nRDhnRSBtB7ZLAZ96k8Saj4ssfFl/fw6fqHh5yFiWC0RoRFCMbEyoAIwBz3PNcjeXF5cXDNfTTyzg4YzsWbPvnmq9XtI1nUtA1KPUNKvJbS7j+7JGcHHoR0I9jwa9/+FfxZvtfhv8AStf1ezj1bCtYz3MQVH4wVYKVyc4PUHk+lcJ8cvEeheIPE1l/Y8sNzNawGO6uoEwkjZyAD3xz69epry2iiiiv/9k=\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FC34970>\" /></td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               9</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+pIYJbiVYoInlkbgIilifwFSXlheadP5F9aT2s2M+XPGUb8iKr1f0S0s77XLG01C7FnZzTKk1wf8AlmpPLV9EXGh6zpyLpvwzTw7pljMqhdRa7WW5us4O4HDcdR39sVL47u08P/De9sPG2qW+p6jd2qx2dsi/OJlQKZA2ASu8CTJAwcjJzXzHRXtHwx0jTvBnhC6+JHiG3DupMelQswzI3IyB2JOQCegDHGOa8q8Qa9feJddu9W1CUvcXEhbGSQgzwq56ADgCsyiug1zxnq3iHQ9H0e9eIWekw+VAkSld3AAZueWwAM8fqa5+iv/Z\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FC37610>\" />                                        </td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    <small>(Showing first 10 rows)</small>\n",
       "</div>"
      ],
      "text/plain": [
       "+----------------------+---------+----------------------+------------------+\n",
       "| image                |   label | image_2d             | pil_image        |\n",
       "| List[Int64]          |   Int64 | Python               | Python           |\n",
       "+======================+=========+======================+==================+\n",
       "| [0, 0, 0, 0, 0, 0,   |       7 | [[  0   0   0   0    | <PIL.Image.Image |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |\n",
       "|                      |         |                      | 0x10FC36F50>     |\n",
       "+----------------------+---------+----------------------+------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       2 | [[  0   0   0   0    | <PIL.Image.Image |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |\n",
       "|                      |         |                      | 0x10FC37040>     |\n",
       "+----------------------+---------+----------------------+------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       1 | [[  0   0   0   0    | <PIL.Image.Image |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |\n",
       "|                      |         |                      | 0x10FC36920>     |\n",
       "+----------------------+---------+----------------------+------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       0 | [[  0   0   0   0    | <PIL.Image.Image |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |\n",
       "|                      |         |                      | 0x10FC36FE0>     |\n",
       "+----------------------+---------+----------------------+------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       4 | [[  0   0   0   0    | <PIL.Image.Image |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |\n",
       "|                      |         |                      | 0x10FC37010>     |\n",
       "+----------------------+---------+----------------------+------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       1 | [[  0   0   0   0    | <PIL.Image.Image |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |\n",
       "|                      |         |                      | 0x10FC37070>     |\n",
       "+----------------------+---------+----------------------+------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       4 | [[  0   0   0   0    | <PIL.Image.Image |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |\n",
       "|                      |         |                      | 0x10FC370A0>     |\n",
       "+----------------------+---------+----------------------+------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       9 | [[  0   0   0   0    | <PIL.Image.Image |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |\n",
       "|                      |         |                      | 0x10FC37130>     |\n",
       "+----------------------+---------+----------------------+------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       5 | [[  0   0   0   0    | <PIL.Image.Image |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |\n",
       "|                      |         |                      | 0x10FC34970>     |\n",
       "+----------------------+---------+----------------------+------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       9 | [[  0   0   0   0    | <PIL.Image.Image |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |\n",
       "|                      |         |                      | 0x10FC37610>     |\n",
       "+----------------------+---------+----------------------+------------------+\n",
       "(Showing first 10 rows)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "images_df.show(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6b633f4-3d9d-4c25-9075-bc815d8e357f",
   "metadata": {},
   "source": [
    "Amazing! This looks great, and we can finally see what the images in the dataset actually look like."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cd7e6774-9fb7-4827-a324-c116c8c812e1",
   "metadata": {},
   "source": [
    "## Running a model with UDFs\n",
    "\n",
    "Next, let's try to run a deep learning model to classify each image. Models are expensive to initialize and load, so we want to do this as few times as possible, and share a model across multiple invocations.\n",
    "\n",
    "For the convenience of this quickstart tutorial, we pre-trained a model using a PyTorch-provided example script and saved the trained weights at https://github.com/Eventual-Inc/mnist-json/raw/master/mnist_cnn.pt. We need to define the same deep learning model \"scaffold\" as the trained model that we want to load. This part is all PyTorch and is not specific to Daft."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5ff43066-8a42-4773-974f-160ca4a9bc49",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "###\n",
    "# Model was trained using a script provided in PyTorch Examples: https://github.com/pytorch/examples/blob/main/mnist/main.py\n",
    "###\n",
    "\n",
    "import torch\n",
    "import torch.hub\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
    "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
    "        self.dropout1 = nn.Dropout(0.25)\n",
    "        self.dropout2 = nn.Dropout(0.5)\n",
    "        self.fc1 = nn.Linear(9216, 128)\n",
    "        self.fc2 = nn.Linear(128, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool2d(x, 2)\n",
    "        x = self.dropout1(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = self.fc1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.dropout2(x)\n",
    "        x = self.fc2(x)\n",
    "        output = F.log_softmax(x, dim=1)\n",
    "        return output"
   ]
  },
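  {
   "cell_type": "markdown",
   "id": "fc1-input-size-sketch",
   "metadata": {},
   "source": [
    "As a quick sanity check on the scaffold above, the `9216` passed to the first `nn.Linear` layer can be derived from the layer shapes. The sketch below is plain Python arithmetic and assumes 28x28 inputs and unpadded convolutions (the `nn.Conv2d` default); `conv_out` is a helper introduced here for illustration only.\n",
    "\n",
    "```python\n",
    "def conv_out(size, kernel, stride=1):\n",
    "    # Output width of an unpadded convolution over an input of width `size`\n",
    "    return (size - kernel) // stride + 1\n",
    "\n",
    "\n",
    "side = 28                     # MNIST images are 28x28\n",
    "side = conv_out(side, 3)      # conv1 (3x3 kernel, stride 1): 28 -> 26\n",
    "side = conv_out(side, 3)      # conv2 (3x3 kernel, stride 1): 26 -> 24\n",
    "side = side // 2              # max_pool2d(x, 2): 24 -> 12\n",
    "flattened = 64 * side * side  # conv2 produces 64 channels\n",
    "print(flattened)              # 9216\n",
    "```\n",
    "\n",
    "If the image size or a kernel size were different, the flattened size would differ too, which is why the scaffold has to match the architecture the weights were trained with."
   ]
  },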
  {
   "cell_type": "markdown",
   "id": "266c1cf8-bf9a-4990-8182-97b072f15b57",
   "metadata": {},
   "source": [
    "Now comes the fun part - we can define a UDF using the `@udf` decorator. Notice that for a batch of data we only initialize our model once!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fda097ea-4946-483c-bcc0-5271e0b033c3",
   "metadata": {
    "tags": []
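   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "@udf(return_dtype=DataType.int64())\n",
    "class ClassifyImages:\n",
    "    def __init__(self):\n",
    "        # Perform expensive initializations - create the model, download model weights and load up the model with weights\n",
    "        self.model = Net()\n",
    "        state_dict = torch.hub.load_state_dict_from_url(\n",
    "            \"https://github.com/Eventual-Inc/mnist-json/raw/master/mnist_cnn.pt\"\n",
    "        )\n",
    "        self.model.load_state_dict(state_dict)\n",
    "        # Switch to inference mode so the dropout layers are disabled and predictions are deterministic\n",
    "        self.model.eval()\n",
    "\n",
    "    def __call__(self, images_2d_col):\n",
    "        # Stack the batch into an array of shape (N, 28, 28) and scale pixel values to [0, 1]\n",
    "        images_arr = np.array(images_2d_col.to_pylist())\n",
    "        normalized_image_2d = images_arr / 255\n",
    "        # Add a channel axis: (N, 28, 28) -> (N, 1, 28, 28), the layout Conv2d expects\n",
    "        normalized_image_2d = normalized_image_2d[:, np.newaxis, :, :]\n",
    "        classifications = self.model(torch.from_numpy(normalized_image_2d).float())\n",
    "        # The predicted digit is the index of the largest log-probability for each image\n",
    "        return classifications.detach().numpy().argmax(axis=1)"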
   ]
  },
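  {
   "cell_type": "markdown",
   "id": "udf-init-once-sketch",
   "metadata": {},
   "source": [
    "To make the \"initialize once, call many times\" idea concrete, here is a minimal plain-Python sketch of the same class shape. It does not use Daft and says nothing about how Daft schedules UDFs; `CountingClassifier` and its counter are hypothetical names used only for illustration.\n",
    "\n",
    "```python\n",
    "class CountingClassifier:\n",
    "    # Hypothetical stand-in for ClassifyImages; it loads no real model\n",
    "    init_calls = 0\n",
    "\n",
    "    def __init__(self):\n",
    "        # Expensive setup would go here (building the model, loading weights)\n",
    "        CountingClassifier.init_calls += 1\n",
    "\n",
    "    def __call__(self, batch):\n",
    "        # Per-batch work reuses whatever __init__ prepared\n",
    "        return [len(item) for item in batch]\n",
    "\n",
    "\n",
    "classifier = CountingClassifier()  # setup runs once\n",
    "batches = [['a', 'bb'], ['ccc'], ['dddd', 'e']]\n",
    "results = [classifier(batch) for batch in batches]\n",
    "print(CountingClassifier.init_calls, results)\n",
    "```\n",
    "\n",
    "The counter stays at 1 no matter how many batches are processed, which is the property we want for an expensive model load."
   ]
  },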
  {
   "cell_type": "markdown",
   "id": "3605d3a6-f9ce-4e81-9e0f-5190f981bbd4",
   "metadata": {},
   "source": [
    "Using this UDF is straightforward: we simply call it on the columns that we want to process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "4f9fd9f8-a231-44fb-a519-0288f670a34a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "    <table class=\"dataframe\">\n",
       "<thead>\n",
       "<tr><th>image<br>List[Int64]                                        </th><th style=\"text-align: right;\">  label<br>Int64</th><th>image_2d<br>Python                               </th><th>pil_image<br>Python                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 </th><th style=\"text-align: right;\">  model_classification<br>Int64</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               7</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APA4Lea6nSC3hkmmc4SONSzMfQAdadc2dzZSmK7tpoJB/BKhU/kahoq1pupXuj6hDf6dcyW13CSY5YzhlyCDj8CRXXWvxe8b20SxtrH2lVOQbqCOU9OmWUn/APVXUfEfxBqCfDzSNJ16S2uNd1JxqEqpbohtIMYjQbQBlsEnv1HpXj9Fdx8OvDNlqNxe+IdeVh4e0VPPucLnznyNkQ/3j1/LjOa57xPr9z4n8R3usXQ2vcSZVB0jQcKo9gABWRRXSxeOdXt/A0nhGAW0WnSzGaZ1j/ey8g7SxOMZA6AHjrXNUV//2Q==\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FCCA800>\" />                                                    </td><td style=\"text-align: right;\">                              7</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               2</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APAY42lkWNBl2IVR6k17BB8L/BemalZaB4h8TXr+I7oKhtNNi3rDI3TcdpyB36HvwK8z8UaL/wAI74o1LRxOs4s52iEg/iAPGff196yaK9d+H1lbeCPCdz8SNZjWW5bdBpFtJw0jn5TIDn/eHToCe4ryvUb+51XUrnULyQyXNzK0srnuzHJqtRX0J4utvBHxCXSLez+INlo+m2dqqQ2EkQCIf7xLMoB24XB9Pc1wp8D+AdMc/wBrfEaCbazDy9Os2l3YHGHBIHPtj3riNfg0a31V00G8ubuw2qUkuYhG+ccgge9ZlFFFf//Z\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FCCBFD0>\" />                                                </td><td style=\"text-align: right;\">                              2</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               1</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+lALEAAkngAd62PFGgjw1rR0trrz5o4Ynn/d7PLkZAzR9TnbnGfboKxqK634a6PHrPjrT1uAhs7Mm9ut/TyovmOfrgD8awte1M614h1LVCGX7ZcyThWOSoZiQM+2cVn0V6J4RFvo/wu8X63NKi3F6qaVarxuYsQ0nvjaR+XevO6KKKKK//9k=\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FCC9A50>\" />                                                                                                                                                                                </td><td style=\"text-align: right;\">                              1</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               0</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+u6+Hnwx1Lx9cTlZmsLGGMn7W8JdWfOAijIz3yc8Y+lcdqFjPpt9NaXMbJJE5U7lIzg4yM9qrVJbwtcXMUKqzNI4QBRkkk44r6F8c/Em3+GdnD4I8KWyi5sYVWS4lXiEsA/TGHZg24nplu5zjC11T8SPgyfFupp5WtaK7RG4Cqi3SFl4J46AjA9QcferxOtDQr9NK8Q6ZqMiF0tLuKdlHUhXDEfpX0VqHgHwJ461y88a3niUy2Ny0P7uKdIY0KoqbXZgTzt6fKetcF8VfiLpOo6Ta+E/CStBpNsx+0MkflpKRjaqgdVzkkkcnB9z5FRRRRX//2Q==\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FCC99C0>\" />                            </td><td style=\"text-align: right;\">                              0</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               4</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+tC60W7s9E0/Vpggtr95kgGTuPl7QxPGMZbA57Gohpl0dHOqhAbRZxbswYZDldwBHuAefY1Uq/oenDWNf07TDIIxd3McBc/w7mAz+tb3xG1NLvxZcadZxmHS9JJsLKDGAiIcM2PVm3MSeeea2PElpH4a+Enh/SWTN7rVwdXmc8FIwmyNfoQxP5151T4ZpLeeOeF2jljYOjqcFWByCK7GT4qeKJRvkk097njN02nQNMSO5YpyccV0Hx4u3uPFOipI/72PRoTNGDgRyFnJG3+E9OPTFeV0UVJPPNczvPcSvLM53PJIxZmPqSetR1//Z\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FCCAA10>\" />                                        </td><td style=\"text-align: right;\">                              4</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               1</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+nIjSOqIpZmOAB1Jrb8V+EtT8GarFpuq+R9pkgSfEMm8KGzwfQgggj8sjBrCorvvg3oI1z4kaeZULW1jm8lO3IGz7uf8AgW2sTx54iPirxvqurhswzTFYOMful+VOO3ygE+5Nc5RXqvgG4Hh34TeM/EUIH22Yx6dC4bDRhupH/fYP/Aa8qooqVLidIJIEmkWGQgvGGIViOmR0OKior//Z\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FCCAC80>\" />                                                                                                                                                </td><td style=\"text-align: right;\">                              1</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               4</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn8DJwOtbWveF9R8NxWB1PyYp7yHz1tRJmWJD08xf4SRyAefyNVtL0PUNZivpbOJWisYDcXEjuEWNB6knqTwB1JrOrc8HXGkWnjDS7jXU36XHOGuBgnjsSByQDg8eneu68W+F9I8QeIbrXm+I+iTR30jSjzN/moucBPLUEjCgAA4zjgVB8RYbPwLo1v4D0m6FxM7C71e52KGlfH7uPjkKvzHB/vCvMKK734O6JFrPxCtmnjEsOnxPfNDs3GUpjaoHruZT+Fc94ui1ZfE15PrcSw6jdSG4miDKTGXJO0gH5T7Hkd6w6Kmtru5spvOtbiWCXBXfE5VsHgjIqEkk5JyaK//9k=\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FCCACB0>\" />            </td><td style=\"text-align: right;\">                              4</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               9</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+tODw9q9zod1rUOnztplqVE11twiksFABPU5I6ZxmsyiivqLRLvU9Lm0Tw8ui2EngQ6RHJdX8ynyjuQu7MzYXlv4SP4sn28K+JegweHPH+qafaW8kFnvEturEEbGGflx/DkkD6c81yVFasniXWpfD8egyancNpUb70tS3yA/4c5x0rsdQuB48+G8V40gbXvDSeXcA8vcWZICv6koSAfQHPevOaKKuabqt9o9xJPYXDQSSwvBJgAh43GGUg8EEf0PUVTr/2Q==\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FCCA9E0>\" />                                                                                            </td><td style=\"text-align: right;\">                              9</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               5</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+tnw34W1nxbqJsNFszczqu9xvVQi5xkliBjmvRV/Z78QxRxPe6zo9qHzkNKxKnB/2QDzjoe9eWarp76TrF7p0kscr2lxJA0kRyrlGKkqfQ44rp/hh4Qi8aeMY7C53m0gha6nRDhnRSBtB7ZLAZ96k8Saj4ssfFl/fw6fqHh5yFiWC0RoRFCMbEyoAIwBz3PNcjeXF5cXDNfTTyzg4YzsWbPvnmq9XtI1nUtA1KPUNKvJbS7j+7JGcHHoR0I9jwa9/+FfxZvtfhv8AStf1ezj1bCtYz3MQVH4wVYKVyc4PUHk+lcJ8cvEeheIPE1l/Y8sNzNawGO6uoEwkjZyAD3xz69epry2iiiiv/9k=\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FCCAF50>\" /></td><td style=\"text-align: right;\">                              6</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               9</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+pIYJbiVYoInlkbgIilifwFSXlheadP5F9aT2s2M+XPGUb8iKr1f0S0s77XLG01C7FnZzTKk1wf8AlmpPLV9EXGh6zpyLpvwzTw7pljMqhdRa7WW5us4O4HDcdR39sVL47u08P/De9sPG2qW+p6jd2qx2dsi/OJlQKZA2ASu8CTJAwcjJzXzHRXtHwx0jTvBnhC6+JHiG3DupMelQswzI3IyB2JOQCegDHGOa8q8Qa9feJddu9W1CUvcXEhbGSQgzwq56ADgCsyiug1zxnq3iHQ9H0e9eIWekw+VAkSld3AAZueWwAM8fqa5+iv/Z\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x10FCCAD40>\" />                                        </td><td style=\"text-align: right;\">                              9</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    <small>(Showing first 10 rows)</small>\n",
       "</div>"
      ],
      "text/plain": [
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| image                |   label | image_2d             | pil_image        |   model_classification |\n",
       "| List[Int64]          |   Int64 | Python               | Python           |                  Int64 |\n",
       "+======================+=========+======================+==================+========================+\n",
       "| [0, 0, 0, 0, 0, 0,   |       7 | [[  0   0   0   0    | <PIL.Image.Image |                      7 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x10FCCA800>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       2 | [[  0   0   0   0    | <PIL.Image.Image |                      2 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x10FCCBFD0>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       1 | [[  0   0   0   0    | <PIL.Image.Image |                      1 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x10FCC9A50>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       0 | [[  0   0   0   0    | <PIL.Image.Image |                      0 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x10FCC99C0>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       4 | [[  0   0   0   0    | <PIL.Image.Image |                      4 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x10FCCAA10>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       1 | [[  0   0   0   0    | <PIL.Image.Image |                      1 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x10FCCAC80>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       4 | [[  0   0   0   0    | <PIL.Image.Image |                      4 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x10FCCACB0>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       9 | [[  0   0   0   0    | <PIL.Image.Image |                      9 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x10FCCA9E0>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       5 | [[  0   0   0   0    | <PIL.Image.Image |                      6 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x10FCCAF50>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       9 | [[  0   0   0   0    | <PIL.Image.Image |                      9 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x10FCCAD40>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "(Showing first 10 rows)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classified_images_df = images_df.with_column(\"model_classification\", ClassifyImages(col(\"image_2d\")))\n",
    "\n",
    "classified_images_df.show(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e6fb5fc-957d-414b-bfac-961ea64dad68",
   "metadata": {},
   "source": [
    "Our model ran successfully and produced a new `model_classification` column. The predictions look pretty good, so let's filter our DataFrame to show only the rows that the model classified incorrectly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "69344d63-7db4-496f-a0b2-949dfd947e4f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "    <table class=\"dataframe\">\n",
       "<thead>\n",
       "<tr><th>image<br>List[Int64]                                        </th><th style=\"text-align: right;\">  label<br>Int64</th><th>image_2d<br>Python                               </th><th>pil_image<br>Python                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         </th><th style=\"text-align: right;\">  model_classification<br>Int64</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               2</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+u90r4YTvo8eseJNasfDthKN0P2vLTyr2ZYhzg8988dMUal8PdNfw5f634X8V2+uQaaA17G1pJayRqxABUNnd3z06dzxXBVe0fUv7I1i01EWtvdG2kEghuF3RuR2YdxT9b1zUvEWqS6hqd1JcXEhPLsSFGchVB6KOwruGsn8A/Da/i1HfHrPiaONYbXjMNsrBi7g9C3Kgdq82rqfAPhCPxp4hOnT6pBp0KRGV5ZSMsNwG1QSMsd1enXEHwf8Ah3rlsVkvNa1CKYElZFmS3PTLYwhxyccnIrkPHOm+G5r+XxAfHH9uG83MIEi23AbsDxtVRnuF6YA9PN6KKKK//9k=\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x1509F00D0>\" />        </td><td style=\"text-align: right;\">                              7</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               3</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+u38DeC9I8Y6dqMMviCLTtaiIa2huMCKVMc5b1zgcdPQ545rXtA1Lwzq82l6tbNb3UXVTyGHZge4PrWbRXoGhfCHX/Eml2moaTf6PcRzqGaNbv95BntIu3g+wzVj4qanYPaeHvD8OqHV9Q0aCSC81DbgOSwwgPfbgjP8AXNeb1u+DdN0vV/F2nWWtX0VlpskmZ5pZBGoUAnG48DOMZ967fQtF8LeBteg1vWPGNndzWkwlt7TQmNxvI5wzkAAdvf1rz/xJqkeueJ9U1WGEwx3l1JOsbHJUMxIB/Osuiiiiv//Z\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x1509F0FA0>\" />                                                </td><td style=\"text-align: right;\">                              2</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               9</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+trw74S17xZcSQ6Hpst48QBkKlVVc9MsxAH50/wAS+DPEHhB7dNd05rQ3ClosyI4YDrypOOo4NYVFe5/BrxBa6n4XuvAtvdXWk6tcSvPHf2sYJZPlJ5zkNgEZ6Yx36r8e9QsrHRtA8KR3sl9fWf72eaaTfKBtAG8+rZJ/AV4XU1raz313DaWsTzXEziOONBksxOABX0BqEWmfA74fmG3ZZfFuqxFPPXBMZxyRnoi9uOTjPt4Df393ql9Ne31xJcXUzbpJZGyzH3NV6mtLuewvYLy2k8u4t5FlifAO1lOQcHjqKu6/4i1bxPqjalrN493dsoXeyhQABgAKoAA+g96zKK//2Q==\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x1509F1180>\" /></td><td style=\"text-align: right;\">                              8</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               4</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+ug8LeFLjxLPcSNcJZaZZp5t7fyqSkCduB1Y9l6msa9it4b6eK1uDc26SMsU5jKeYoPDbTyMjnFQUV6H4s3eHvBXh/wAJWSyGXUIU1W/IQEzPIP3SDvhQD06k+tW7Dw9Z+BPBV1r3iW3P9t6nA9vpVg4w0asu1pnB6cMcAjjA7njzGivcPAfxA1XSPh/e69rN5FeQaViw0u2aJDI0hUEBnxuCAbenp7YryHXtf1PxLq0up6tdyXNzIerHhBkkKo7KMnAFZtFFFFf/2Q==\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x1509F1A80>\" />                                                                            </td><td style=\"text-align: right;\">                              8</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               2</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+tfSPC+ta9aXdzpVi92loAZliZS4B7hM7m/AGskgqSCCCOCDSVc0nS7vW9VttNsIjLc3D7EXOPqSewAySewFd/wCFfCvhvUtWW20vxJqtvqdihuZdU+yolpDsGSd2/cozwGPX0rnviNZ3tv451O4vLVYFvZmuoGRgySxsSVdWHBB68d81ytW9M1G40nUre/tSomgfcAwyreoI7gjII7gmt3W/F8d5p0mlaJpMGi6ZNJ5s8MEjO87ZyA7tyVU52rwB7nmpfD2tWOo6cvhfxJKV05nLWV91bT5T394mONy/iORzz+saTeaFq1zpl/H5d1bvtdQcj1BB7gggg9wRVKiiiv/Z\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x1509F1FC0>\" />    </td><td style=\"text-align: right;\">                              9</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               7</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+tXTvDmqarbG5toYktw4jE1zcR28bMf4Q0jKGPsCTVO/sLvS7+exvoHguoHKSxOMFWHaq1Fa/h+zn1zX9J0os8kT3CxhGchUQtlz/ALIxkkj60viu/GqeLNVvFcPHJdP5bAY/dg4T/wAdArHorpvB6zWh1bXIsA6bZOULIWBkl/dAcd8O7f8AAT6VzNFFXrTWNQsdMvtOtrp4rS/CC6iUDEoQkrk9eCT0qjRX/9k=\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x1509F24A0>\" />                                                                                                                    </td><td style=\"text-align: right;\">                              9</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               5</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+nwwy3EyQwxvLK5wqIpZmPoAOtdO/w48WRadNfTaT5MUMDXMqTXESSpGDgsYmYOBx6c9q5WivTrm/g+HdpocdhpjSLqVkl1daoshjmuFcZaKGTafKC9CVBJ46VzOveNrrVrKTTrGxttK0yRlaS3tgS05HQyyNlnI9zj2rl6nsrO41G9gsrSIy3M7iOKNerMTgCu/+Hj32uRS+HdXtUufC0G+W6muCU/s7jmSOT+BuPunIY547157OsaXEiwuXiDkI5GNwzwcVHTo5HhlSWJ2SRCGVlOCpHQg1t6x4z8Ra/Zraapqs9xbqwbyzhQzAYBbAG447nJrCor//2Q==\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x1509F2B60>\" />            </td><td style=\"text-align: right;\">                              8</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               6</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+nRxvNIscaM8jkKqqMliegAr0bTfg5qn9npqHiXVdO8N20gzGL+QCVh/uZGPoSD7Vwus2MGmaxdWVtfwahDC+1bqDOyQeoz/n0yOao17p4J8ISeCPAn/CeSaNPq+uTxh9OtY4i626MOJWA56c57DA4ySPH9f8Qar4l1aXUdXu5Li5c9W4CD+6o6AD0FZdFfVElxqFzr2leLU8XQ6b4Ihs4pDCJdvmFRyhTHckA9TxgDpXzn4x1a013xjq2qWEHk2t1ctJGhGDg9yOxPX8aw6KKKK//9k=\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x1509F2C80>\" />                                                                </td><td style=\"text-align: right;\">                              5</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               4</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+rFhZy6hqNtZQRtJNcSrEiL1YscACtvx7YaTpXjnVtP0NXWwtZvJRXYsQygB+TyfmDVzlFes/DLwVrVjZ3fjWTSLmU2luTpVv5JZrid/lRwvXYu7dnGO/Y1ifEz4fXXgu7sWxe3ST2ySXd48ZMX2li25VfGO2cHmuBre8Gf2Ivi/TX8RzCLSY5fMuCY2cMFBIUhQSQSADx0NehzfGSK5+JC63LFcro2nW8sem2MfAL7SFZ1BA5P1xxjpXA+JPHfiTxZJJ/a2qzywM+8WytthXnIAQccds5PvXOUUUUV//2Q==\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x1509F3220>\" />                                                </td><td style=\"text-align: right;\">                              2</td></tr>\n",
       "<tr><td>[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...</td><td style=\"text-align: right;\">               2</td><td>&ltnp.ndarray<br>shape=(28, 28)<br>dtype=int64&gt</td><td><img style=\"max-height:128px;width:auto\" src=\"data:image/png;base64, /9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APAFVnYKoJYnAAHJNeoad8F7mHT4L7xd4i0zwzFcKWhiu3DTNjkgqWUZxjgEkZ5APFch408Jy+D9eOnm7jvbaSNZra8iXCTxn+Ickdcjgnp1rna2vCOoWGleL9Jv9Uh86xt7lJJk27vlB6474649qXWrzVfFXie7ne5uNYu5ZWCSxxsTIoPG1MZVcdFxxmur+JSHSPDXgrwxN5n2zT9Pe5uFkBDRvcMH8s+m3bjFec0V9WprmhfDDwRbX9rLp7aVNYRCxjhgK3V/PtyXdt2Mc5PHGTz0B8W8U+NvD3jrS/t2t6Zc2PiiIOon01FNvcjA2eYHbcu3AHBbv6gDzuiiiiv/2Q==\" alt=\"<PIL.Image.Image image mode=L size=28x28 at 0x1509F37F0>\" />    </td><td style=\"text-align: right;\">                              8</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    <small>(Showing first 10 rows)</small>\n",
       "</div>"
      ],
      "text/plain": [
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| image                |   label | image_2d             | pil_image        |   model_classification |\n",
       "| List[Int64]          |   Int64 | Python               | Python           |                  Int64 |\n",
       "+======================+=========+======================+==================+========================+\n",
       "| [0, 0, 0, 0, 0, 0,   |       2 | [[  0   0   0   0    | <PIL.Image.Image |                      7 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x1509F00D0>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       3 | [[  0   0   0   0    | <PIL.Image.Image |                      2 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x1509F0FA0>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       9 | [[  0   0   0   0    | <PIL.Image.Image |                      8 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x1509F1180>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       4 | [[  0   0   0   0    | <PIL.Image.Image |                      8 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x1509F1A80>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       2 | [[  0   0   0   0    | <PIL.Image.Image |                      9 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x1509F1FC0>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       7 | [[  0   0   0   0    | <PIL.Image.Image |                      9 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x1509F24A0>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       5 | [[  0   0   0   0    | <PIL.Image.Image |                      8 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x1509F2B60>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       6 | [[  0   0   0   0    | <PIL.Image.Image |                      5 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x1509F2C80>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       4 | [[  0   0   0   0    | <PIL.Image.Image |                      2 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x1509F3220>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "| [0, 0, 0, 0, 0, 0,   |       2 | [[  0   0   0   0    | <PIL.Image.Image |                      8 |\n",
       "| 0, 0, 0, 0, 0, 0, 0, |         | 0   0   0   0   0    | image mode=L     |                        |\n",
       "| 0, 0, 0, 0, 0, 0,... |         | 0   0   0   0   0... | size=28x28 at    |                        |\n",
       "|                      |         |                      | 0x1509F37F0>     |                        |\n",
       "+----------------------+---------+----------------------+------------------+------------------------+\n",
       "(Showing first 10 rows)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "classified_images_df.where(col(\"label\") != col(\"model_classification\")).show(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb7ca72b-0743-451d-a3bf-e492a73ad7d6",
   "metadata": {},
   "source": [
    "Some of these digits are genuinely hard to read, even for a human!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5482e99e-cf3a-4d54-93e3-6e468db03eef",
   "metadata": {},
   "source": [
    "## Analytics\n",
    "\n",
    "We have run our model over the dataset, but how well did it actually do? Daft Dataframes support groupby and aggregation operations that let us summarize results across the whole dataset.\n",
    "\n",
    "Let's group our data by the true labels and calculate how many mistakes our model made per label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8b60eef9-eeab-435e-9f5d-c775af9afe3f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "    <table class=\"dataframe\">\n",
       "<thead>\n",
       "<tr><th style=\"text-align: right;\">  label<br>Int64</th><th style=\"text-align: right;\">  num_rows<br>UInt64</th><th style=\"text-align: right;\">  correct<br>Int64</th><th style=\"text-align: right;\">  wrong<br>Int64</th></tr>\n",
       "</thead>\n",
       "<tbody>\n",
       "<tr><td style=\"text-align: right;\">               0</td><td style=\"text-align: right;\">                 980</td><td style=\"text-align: right;\">               957</td><td style=\"text-align: right;\">              23</td></tr>\n",
       "<tr><td style=\"text-align: right;\">               1</td><td style=\"text-align: right;\">                1135</td><td style=\"text-align: right;\">              1123</td><td style=\"text-align: right;\">              12</td></tr>\n",
       "<tr><td style=\"text-align: right;\">               2</td><td style=\"text-align: right;\">                1032</td><td style=\"text-align: right;\">               996</td><td style=\"text-align: right;\">              36</td></tr>\n",
       "<tr><td style=\"text-align: right;\">               3</td><td style=\"text-align: right;\">                1010</td><td style=\"text-align: right;\">               965</td><td style=\"text-align: right;\">              45</td></tr>\n",
       "<tr><td style=\"text-align: right;\">               4</td><td style=\"text-align: right;\">                 982</td><td style=\"text-align: right;\">               951</td><td style=\"text-align: right;\">              31</td></tr>\n",
       "<tr><td style=\"text-align: right;\">               5</td><td style=\"text-align: right;\">                 892</td><td style=\"text-align: right;\">               830</td><td style=\"text-align: right;\">              62</td></tr>\n",
       "<tr><td style=\"text-align: right;\">               6</td><td style=\"text-align: right;\">                 958</td><td style=\"text-align: right;\">               925</td><td style=\"text-align: right;\">              33</td></tr>\n",
       "<tr><td style=\"text-align: right;\">               7</td><td style=\"text-align: right;\">                1028</td><td style=\"text-align: right;\">               971</td><td style=\"text-align: right;\">              57</td></tr>\n",
       "</tbody>\n",
       "</table>\n",
       "    <small>(Showing first 8 rows)</small>\n",
       "</div>"
      ],
      "text/plain": [
       "+---------+------------+-----------+---------+\n",
       "|   label |   num_rows |   correct |   wrong |\n",
       "|   Int64 |     UInt64 |     Int64 |   Int64 |\n",
       "+=========+============+===========+=========+\n",
       "|       0 |        980 |       957 |      23 |\n",
       "+---------+------------+-----------+---------+\n",
       "|       1 |       1135 |      1123 |      12 |\n",
       "+---------+------------+-----------+---------+\n",
       "|       2 |       1032 |       996 |      36 |\n",
       "+---------+------------+-----------+---------+\n",
       "|       3 |       1010 |       965 |      45 |\n",
       "+---------+------------+-----------+---------+\n",
       "|       4 |        982 |       951 |      31 |\n",
       "+---------+------------+-----------+---------+\n",
       "|       5 |        892 |       830 |      62 |\n",
       "+---------+------------+-----------+---------+\n",
       "|       6 |        958 |       925 |      33 |\n",
       "+---------+------------+-----------+---------+\n",
       "|       7 |       1028 |       971 |      57 |\n",
       "+---------+------------+-----------+---------+\n",
       "(Showing first 8 rows)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
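    "# Flag each prediction as correct or wrong, casting the booleans to 0/1 so they can be summed.\n",
    "analysis_df = (\n",
    "    classified_images_df.with_column(\"correct\", (col(\"model_classification\") == col(\"label\")).cast(DataType.int64()))\n",
    "    .with_column(\"wrong\", (col(\"model_classification\") != col(\"label\")).cast(DataType.int64()))\n",
    "    # Group by the true label, then count the rows and sum the flags within each group.\n",
    "    .groupby(col(\"label\"))\n",
    "    .agg(\n",
    "        col(\"label\").count().alias(\"num_rows\"),\n",
    "        col(\"correct\").sum(),\n",
    "        col(\"wrong\").sum(),\n",
    "    )\n",
    "    # Sort so the labels appear in order from 0 upward.\n",
    "    .sort(col(\"label\"))\n",
    ")\n",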
    "\n",
    "analysis_df.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05f7df20-6dbc-4115-9acf-8d863cac93af",
   "metadata": {},
   "source": [
    "Pretty impressive, given that the model was trained for only one epoch! In the rows shown above, the number of mistakes per label ranges from 12 (label 1) to 62 (label 5)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "e5d77f7bd5a748e4f6412a25f9708ab7af36936de941fc795d1a6b75eb2da082"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
